﻿#pragma kernel Clear
#pragma kernel Initialize
#pragma kernel Project
#pragma kernel Apply

#include "MathUtils.cginc"
#include "AtomicDeltas.cginc"
#include "ColliderDefinitions.cginc"
#include "Rigidbody.cginc"

// ---- Per-constraint data (indexed by constraint index) ----

// Index of the edge each constraint currently sits on (despite the name, the kernels use it
// as an index into deformableEdges). Negative means the constraint is disabled; Initialize
// writes -1 here when a constraint runs off the end of its edge chain.
RWStructuredBuffer<int> particleIndices;
// Index of the collider the constraint is attached to; negative means "no collider".
StructuredBuffer<int> colliderIndices;
// Pin offset, transformed to world space with the collider's entry in `transforms`.
StructuredBuffer<float4> offsets;
// Interpolation factor along the current edge (0 = first particle, 1 = second particle).
RWStructuredBuffer<float> edgeMus;
// x = first edge, y = last edge of the range the constraint is allowed to move in.
StructuredBuffer<int2> edgeRanges;
// x = lowest allowed factor on the first edge, y = highest allowed factor on the last edge.
StructuredBuffer<float2> edgeRangeMus;
// Five floats per constraint, read as parameters[i * 5 + n]:
//   0: compliance (divided by substepTime^2 before use)
//   1: blend factor between predicted and motor-corrected position ("artificial friction")
//   2: motor target velocity
//   3: motor maximum force
//   4: values > 0.5 clamp the constraint at the end of the chain instead of disabling it
StructuredBuffer<float> parameters;
// Relative velocity along the edge, written by Initialize.
RWStructuredBuffer<float> relativeVelocities;
// Accumulated constraint multipliers, updated by Project.
RWStructuredBuffer<float> lambdas;

// ---- Collider / rigidbody data ----

// Indexed by collider index.
StructuredBuffer<transform> transforms;
// Indexed by collider index; provides the rigidbodyIndex of each collider (negative = none).
StructuredBuffer<shape> shapes;
// Writable rigidbody data; Clear and Initialize update constraintCount atomically.
RWStructuredBuffer<rigidbody> RW_rigidbodies;

// Writable particle positions, passed to ApplyPositionDelta in the Apply kernel.
RWStructuredBuffer<float4> RW_positions;

// ---- Particle / edge data ----

// Two particle indices per edge: [edge * 2] and [edge * 2 + 1].
StructuredBuffer<int> deformableEdges;
// Indexed by particle index.
StructuredBuffer<float4> positions;
StructuredBuffer<float4> prevPositions;
StructuredBuffer<float> invMasses;

// Only element 0 is read by the kernels in this file.
StructuredBuffer<inertialFrame> inertialSolverFrame;

// Variables set from the CPU
// Number of constraints to process; threads with an index at or above this exit immediately.
uint activeConstraintCount;
float stepTime;
float substepTime;
float timeLeft;
int steps;
// Passed to ApplyPositionDelta as its third argument.
float sorFactor;

// Resets the constraint counter of every rigidbody referenced by an active constraint.
// One thread per constraint; several threads may hit the same rigidbody, so the reset is atomic.
[numthreads(128, 1, 1)]
void Clear (uint3 id : SV_DispatchThreadID) 
{
    uint constraintIndex = id.x;

    if (constraintIndex < activeConstraintCount)
    {
        int collider = colliderIndices[constraintIndex];

        // constraints without a collider have nothing to reset.
        if (collider >= 0)
        {
            int body = shapes[collider].rigidbodyIndex;

            // colliders without a rigidbody have no counter either.
            if (body >= 0)
            {
                int previousCount; // required by the intrinsic, value is not used.
                InterlockedExchange(RW_rigidbodies[body].constraintCount, 0, previousCount);
            }
        }
    }
}

// Returns true when nextEdgeIndex is connected to edgeIndex in the direction of travel:
// the two edges must share the particle at the end we are leaving through.
bool IsEdgeValid(int edgeIndex, int nextEdgeIndex, float mix)
{
    // leaving through the start of the current edge: the next edge must end where this one starts.
    if (mix < 0)
        return deformableEdges[nextEdgeIndex * 2 + 1] == deformableEdges[edgeIndex * 2];

    // leaving through the end of the current edge: the next edge must start where this one ends.
    return deformableEdges[nextEdgeIndex * 2] == deformableEdges[edgeIndex * 2 + 1];
}

// Keeps mix inside the limits of constraint i's edge range.
// The lower limit only applies on the first edge of the range and the upper limit only on
// the last one. Returns true if mix had to be modified.
bool ClampToRange(int i, int edgeIndex, inout float mix)
{
    int2 range = edgeRanges[i];
    float2 limits = edgeRangeMus[i];

    // lower limit, checked first.
    bool hitLower = (edgeIndex == range.x) && (mix < limits.x);
    if (hitLower)
        mix = limits.x;

    // upper limit, checked against the (possibly already raised) value.
    bool hitUpper = (edgeIndex == range.y) && (mix > limits.y);
    if (hitUpper)
        mix = limits.y;

    return hitLower || hitUpper;
}

// Per-constraint setup, one thread per constraint:
//  - counts the constraint on the rigidbody it is attached to (if any),
//  - finds where along the current edge the predicted pinhole position falls,
//  - applies the motor and friction parameters to that position,
//  - walks to adjacent edges when the resulting position leaves the current edge,
//  - stores the resulting edge index (particleIndices) and edge factor (edgeMus).
// A constraint that runs past the end of its chain is either clamped to the end or disabled
// (particleIndices[i] = -1), depending on parameters[i * 5 + 4].
[numthreads(128, 1, 1)]
void Initialize (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;
   
    if (i >= activeConstraintCount) return;

    int edgeIndex = particleIndices[i];
    int colliderIndex = colliderIndices[i];

    // if no collider or edge, ignore the constraint.
    if (edgeIndex < 0 || colliderIndex < 0)
        return;

    int rigidbodyIndex = shapes[colliderIndex].rigidbodyIndex;

    // count this constraint on its rigidbody; Project multiplies the rigidbody's inverse masses by this count.
    // NOTE(review): a constraint that gets disabled further down has already been counted here - confirm this is acceptable.
    if (rigidbodyIndex >= 0)
    {
        InterlockedAdd(RW_rigidbodies[rigidbodyIndex].constraintCount, 1);
    }

    float frameEnd = stepTime * steps;
    float substepsToEnd = timeLeft / substepTime;

    int p1 = deformableEdges[edgeIndex * 2];
    int p2 = deformableEdges[edgeIndex * 2 + 1];

    // amount of edges in the constraint's range. Zero for an empty (inverted) range.
    int edgeCount = max(0, edgeRanges[i].y - edgeRanges[i].x + 1);

    // express pin offset in world space:
    float4 worldPinOffset = transforms[colliderIndex].TransformPoint(offsets[i]);
    float4 predictedPinOffset = worldPinOffset;
    
    if (rigidbodyIndex >= 0)
    {
        // predict offset point position using rb velocity at that point (can't integrate transform since position != center of mass)
        float4 velocityAtPoint = GetRigidbodyVelocityAtPoint(rigidbodies[rigidbodyIndex],inertialSolverFrame[0].frame.InverseTransformPoint(worldPinOffset), 
                                                             asfloat(linearDeltasAsInt[rigidbodyIndex]), 
                                                             asfloat(angularDeltasAsInt[rigidbodyIndex]), inertialSolverFrame[0]);

        predictedPinOffset = IntegrateLinear(predictedPinOffset, inertialSolverFrame[0].frame.TransformVector(velocityAtPoint), frameEnd);
    }

    // transform pinhole position to solver space for constraint solving:
    float4 solverPredictedOffset = inertialSolverFrame[0].frame.InverseTransformPoint(predictedPinOffset);

    // get current edge data:
    float mix = 0;
    float4 particlePosition1 = lerp(prevPositions[p1], positions[p1], substepsToEnd);
    float4 particlePosition2 = lerp(prevPositions[p2], positions[p2], substepsToEnd);
    float edgeLength = length(particlePosition1 - particlePosition2) + EPSILON; // EPSILON avoids dividing by zero on degenerate edges.

    // mix receives the position of the pinhole along the edge.
    // presumably the last argument disables clamping to the [0,1] segment, since mix is tested for <0 / >1 below - confirm against MathUtils.
    NearestPointOnEdge(particlePosition1, particlePosition2, solverPredictedOffset, mix, false);

    // calculate current relative velocity between rope and pinhole:
    float velocity = (mix - edgeMus[i]) / substepTime * edgeLength; // vel = pos / time.
    relativeVelocities[i] = velocity;

    // apply motor force:
    float targetAccel = (parameters[i * 5 + 2] - velocity) / substepTime; // accel = vel / time.
    float maxAccel = parameters[i * 5 + 3] * max(lerp(invMasses[p1], invMasses[p2], mix), EPSILON); // accel = force / mass. Guard against inf*0
    velocity += clamp(targetAccel, -maxAccel, maxAccel) * substepTime; 

    // calculate new position by adding motor acceleration:
    float corrMix = edgeMus[i] + velocity * substepTime / edgeLength;

    // apply artificial friction by interpolating predicted position and corrected position.
    mix = lerp(mix, corrMix, parameters[i * 5 + 1]);

    // move to an adjacent simplex if needed
    if (!ClampToRange(i, edgeIndex, mix) && (mix < 0 || mix > 1))
    {
        bool clampOnEnd = parameters[i * 5 + 4] > 0.5f;

        // calculate distance we need to travel along simplex chain:
        float distToTravel = length(particlePosition1 - particlePosition2) * (mix < 0 ? -mix : mix - 1);

        // becomes true once one of the exit conditions below has produced a final value for mix.
        bool resolved = false;

        // the walk is capped at 10 edges per invocation.
        for (int k = 0; k < 10; ++k)
        {
            // an empty range has no adjacent edge; checking this first also keeps
            // edgeCount from being used as a zero modulus below.
            bool hasNextEdge = edgeCount > 0;
            int nextEdgeIndex = edgeIndex;

            if (hasNextEdge)
            {
                // calculate index of next edge (wraps around inside the range):
                nextEdgeIndex = edgeRanges[i].x + (int)nfmod((mix < 0 ? edgeIndex - 1 : edgeIndex + 1) - edgeRanges[i].x, edgeCount);

                // see if it's actually connected to the current one:
                hasNextEdge = IsEdgeValid(edgeIndex, nextEdgeIndex, mix);
            }

            if (!hasNextEdge)
            {
                // disable constraint if needed
                if (!clampOnEnd) { particleIndices[i] = -1; return; }

                // otherwise clamp to end:
                mix = saturate(mix);
                resolved = true;
                break;
            }

            // advance to next edge:
            edgeIndex = nextEdgeIndex;

            p1 = deformableEdges[edgeIndex*2];
            p2 = deformableEdges[edgeIndex*2 + 1];
            particlePosition1 = lerp(prevPositions[p1], positions[p1], substepsToEnd);
            particlePosition2 = lerp(prevPositions[p2], positions[p2], substepsToEnd);
            edgeLength = length(particlePosition1 - particlePosition2) + EPSILON;

            // stop if we reached target edge:
            if (distToTravel <= edgeLength)
            {
                // when moving backwards (mix < 0) we enter the new edge through its end, so measure from 1.
                mix = mix < 0 ? 1 - saturate(distToTravel / edgeLength) : saturate(distToTravel / edgeLength);
                ClampToRange(i, edgeIndex, mix);
                resolved = true;
                break;
            }

            // stop if we reached end of range:
            if (ClampToRange(i, edgeIndex, mix))
            {
                resolved = true;
                break;
            }

            distToTravel -= edgeLength;
        }

        // iteration cap hit before reaching the target: mix still holds the out-of-range value
        // computed for the original edge. Stop at the far end of the last edge reached instead
        // of storing a factor outside [0,1] against an unrelated edge.
        if (!resolved)
            mix = saturate(mix);
    }

    // store new position along edge:
    edgeMus[i] = mix; 
    particleIndices[i] = edgeIndex;
}

// Solves one pinhole constraint per thread: pulls the point at edgeMus[i] along the
// constraint's edge towards the predicted pinhole position, accumulating position deltas
// for both edge particles and an opposite impulse on the attached rigidbody (if any).
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    uint c = id.x;

    if (c >= activeConstraintCount) return;

    int edge = particleIndices[c];
    int collider = colliderIndices[c];

    // disabled constraints (no edge) and constraints without collider are skipped.
    if (edge < 0 || collider < 0)
        return;

    // local copy of the solver's inertial frame, used for all space conversions below.
    inertialFrame solverFrame = inertialSolverFrame[0];

    float frameEnd = stepTime * steps;
    float substepsToEnd = timeLeft / substepTime;

    // time adjusted compliance:
    float compliance = parameters[c * 5] / (substepTime * substepTime);

    // particles at both ends of the edge:
    int particleA = deformableEdges[edge * 2];
    int particleB = deformableEdges[edge * 2 + 1];

    // point on the edge the constraint is currently attached to:
    float mu = edgeMus[c];
    float4 positionA = lerp(prevPositions[particleA], positions[particleA], substepsToEnd);
    float4 positionB = lerp(prevPositions[particleB], positions[particleB], substepsToEnd);
    float4 pointOnEdge = lerp(positionA, positionB, mu);

    // pinhole position in world space, and its predicted position (same unless a rigidbody moves it):
    float4 pinWorld = transforms[collider].TransformPoint(offsets[c]);
    float4 pinPredicted = pinWorld;

    // inverse mass contributions of the rigidbody, zero when there's none.
    float bodyLinearW = 0;
    float bodyAngularW = 0;
    
    int body = shapes[collider].rigidbodyIndex;
    bool hasBody = body >= 0;

    if (hasBody)
    {
        rigidbody rb = rigidbodies[body];

        // velocity of the rigidbody at the pinhole (can't integrate the transform since position != center of mass)
        float4 pinVelocity = GetRigidbodyVelocityAtPoint(rb,
                                                         solverFrame.frame.InverseTransformPoint(pinWorld),
                                                         asfloat(linearDeltasAsInt[body]),
                                                         asfloat(angularDeltasAsInt[body]),
                                                         solverFrame);

        pinPredicted = IntegrateLinear(pinPredicted, solverFrame.frame.TransformVector(pinVelocity), frameEnd);
        
        // mass splitting: both effective inverse masses are multiplied by the rigidbody's constraint count.
        float4 worldDirection = normalizesafe(solverFrame.frame.TransformPoint(pointOnEdge) - pinPredicted);
        bodyLinearW = rb.inverseMass * rb.constraintCount; 
        bodyAngularW = RotationalInvMass(rb.inverseInertiaTensor, worldPinOffsetFromCom(pinWorld, rb), worldDirection) * rb.constraintCount;
    }
   
    // constraint is solved in solver space:
    pinPredicted = solverFrame.frame.InverseTransformPoint(pinPredicted);

    float4 toEdge = pointOnEdge - pinPredicted;
    float distance = length(toEdge);
    float4 direction = toEdge / (distance + EPSILON);

    // multiplier update, regularized by compliance:
    float edgeW = lerp(invMasses[particleA], invMasses[particleB], mu);
    float deltaLambda = (-distance - compliance * lambdas[c]) / (edgeW + bodyLinearW + bodyAngularW + compliance + EPSILON);
    lambdas[c] += deltaLambda;

    float4 correction = deltaLambda * direction;

    // distribute the correction among both particles according to the edge factor:
    float4 scaledCorrection = correction * BaryScale(float4(1 - mu, mu, 0, 0));

    AddPositionDelta(particleA, scaledCorrection * invMasses[particleA] * (1 - mu) / substepsToEnd);
    AddPositionDelta(particleB, scaledCorrection * invMasses[particleB] * mu / substepsToEnd);

    // equal and opposite reaction on the rigidbody:
    if (hasBody)
    {
        ApplyImpulse(body,
                     -correction / frameEnd,
                     solverFrame.frame.InverseTransformPoint(pinWorld),
                     solverFrame.frame);
    }   
}

// Applies the position deltas accumulated by Project to both particles of each
// active constraint's edge. One thread per constraint.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID) 
{
    uint c = id.x;
   
    if (c >= activeConstraintCount) return;

    int edge = particleIndices[c];

    // disabled constraints have a negative edge index.
    if (edge >= 0)
    {
        int firstParticle = deformableEdges[edge * 2];
        int secondParticle = deformableEdges[edge * 2 + 1];

        ApplyPositionDelta(RW_positions, firstParticle, sorFactor);
        ApplyPositionDelta(RW_positions, secondParticle, sorFactor);
    }
}